# ------------------------------------------------------------------------------
# Utility functions used throughout the project
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# ------------------------------------------------------------------------------

# Setup ------------------------------------------------------------------------
import(data.table)
import(magrittr) # pipes
import(yaml)
import(testit)
import(here)
import(dplyr)
import(Matrix)
import(purrr) # map
import(stats)
import(ggplot2)
import(glue)
import(tidyr) # replace_na
import(xtable)
import(lfe) # clustered mean outcome SEs by ptid
import(broom)
import(tibble) # rownames_to_column
import(stringr) # str_remove
import(readr) # parse_number

a <- modules::use(here::here("lib", "aesthetics.R"))

sparsify <- function(x){
  # Convert `x` to a sparse Matrix. Inputs with fewer than 2^31 - 1 cells are
  # converted directly; larger inputs are split into a top and bottom half by
  # row, each half is sparsified recursively, and the halves are row-bound.
  n_cells <- prod(dim(x))
  fits_in_one_pass <- n_cells < (2^31 - 1)

  if (!fits_in_one_pass) {
    split_at <- floor(nrow(x) / 2)
    print("iterating...")
    top_half <- sparsify(x[1:split_at, ])
    bottom_half <- sparsify(x[(split_at + 1):nrow(x), ])
    return(rbind(top_half, bottom_half))
  }

  Matrix(data.matrix(x), sparse = TRUE)
}

safe_left_join <- function(x, y, by = intersect(names(x), names(y))) {
  # Left join `y` onto `x` and fail if the join changed the number of rows.
  #
  # Args:
  #   x, y: tables to join; `x` is the left-hand table.
  #   by: join columns; defaults to every column name shared by `x` and `y`.
  # Returns: the joined table, which has exactly nrow(x) rows.
  # Raises: an assertion error when the joined row count differs from nrow(x).
  #
  # `all.x` is an argument of base::merge(), not of left_join(); it was only
  # ever passed through `...`, and recent dplyr releases reject unused
  # arguments there, so it is not passed here.
  xy <- left_join(x, y, by = by)
  assert("Rows of x same pre- and post-merge",
         nrow(xy) == nrow(x))
  xy
}

ntile_within <- function(x, n_tiles, which_x = seq_along(x), showWarnings = TRUE) {
  # Assign every element of `x` to one of `n_tiles` quantile bins, where the
  # interior cutpoints are estimated on the subset `x[which_x]` only.
  #
  # Args:
  #   x: numeric vector to bin.
  #   n_tiles: number of bins.
  #   which_x: indices of `x` used to estimate the quantile cutpoints
  #     (default: all of `x`).
  #   showWarnings: if TRUE, warn when `x[which_x]` contains NA values.
  # Returns: integer vector the same length as `x` with values in 1..n_tiles;
  #   NA wherever `x` is NA.
  # Note: cut() raises an error when the cutpoints are not unique.
  assert("Indices of which_x in x", all(which_x <= length(x)))
  if (isTRUE(showWarnings) && anyNA(x[which_x])) {
    warning("X contains NA values; establishing tiles on non-NA values.")
  }

  # seq_len() keeps n_tiles = 1 well-defined (no interior cutpoints), whereas
  # 1:(n_tiles - 1) would yield c(1, 0) and produce duplicated breaks.
  inner_probs <- seq_len(n_tiles - 1) / n_tiles
  inner_cuts <- if (length(inner_probs) > 0) {
    quantile(x[which_x], inner_probs, na.rm = TRUE)
  } else {
    numeric(0)
  }

  # add min x and max x as end points so every non-NA value falls in a bin
  cutpoints <- c(min(x, na.rm = TRUE), inner_cuts, max(x, na.rm = TRUE))
  x_tiled <- as.integer(
    cut(x, breaks = cutpoints, labels = seq_len(n_tiles), include.lowest = TRUE)
  )
  assert("Num NAs same in x, x_tiled", sum(is.na(x)) == sum(is.na(x_tiled)))
  return(x_tiled)
}

# functions for creating mean tables by tiles
get_population <- function(outcome, ...) {
  # Map a single outcome variable name to the population it is reported on.
  #
  # Args:
  #   outcome: one outcome variable name (length-1 character).
  #   ...: ignored; lets the function be called with extra named arguments.
  # Returns: "all", "tested" or "untested".
  #   - "all": one of the exact names listed below, or any name containing
  #     "test_010_day".
  #   - "tested": any other name containing "stent" or "cabg".
  #   - "untested": everything else. A message is emitted when the name
  #     contains none of "mace", "death" or "ami", since "untested" is then
  #     only a fallback.
  stopifnot(length(outcome) == 1L)

  all_outcomes <- c(
    "stent_tested_or_mace_untested",
    "untested_noecg",
    "untested_not_sameday_tn"
  )

  # `%in%` and grepl() never return NA, so an NA outcome falls through to the
  # "untested" fallback instead of breaking the scalar conditions.
  if (outcome %in% all_outcomes || grepl("test_010_day", outcome)) {
    population <- "all"
  } else if (grepl("stent", outcome) || grepl("cabg", outcome)) {
    population <- "tested"
  } else {
    population <- "untested"
    recognized <- grepl("mace", outcome) || grepl("death", outcome) ||
      grepl("ami", outcome)
    if (!recognized) {
      message(
        "Unrecognized outcome variable ", outcome,
        "; assigning population untested"
      )
    }
  }
  return(population)
}

get_xvar_lab <- function(x_var){
  # Build a label from a tile variable name.
  #
  # Args:
  #   x_var: variable name(s) to relabel.
  # Returns: the label(s) chosen by the first matching rule below; names that
  #   match neither pattern are returned unchanged.
  x_var_lab <- case_when(
    # names containing "tile_stent": take the number parsed from the name,
    # pad it to width 3 with "0", and label it "tile_<number>"
    grepl("tile_stent", x_var) ~ parse_number(x_var) %>%
      str_pad(3, pad = "0") %>%
      {
        glue("tile_{.}")
      },
    # names containing "tile_ensemble_ecg_stent": same padded number, labelled
    # "tile_<number>_ensemble_ecg"
    grepl("tile_ensemble_ecg_stent", x_var) ~ parse_number(x_var) %>%
      str_pad(3, pad = "0") %>%
      {
        glue("tile_{.}_ensemble_ecg")
      },
    # anything else keeps its original name
    TRUE ~ x_var
  )
  return(x_var_lab)
}

get_overall_mean <- function(population, outcome, df, ...){
  # Mean of column `outcome` in `df`, rounded to 3 digits, after restricting
  # rows according to `population`: "untested" keeps rows where test_010_day
  # is FALSE, "tested" keeps rows where it is TRUE, and any other value keeps
  # every row.
  if (population == "untested") {
    rows <- dplyr::filter(df, !test_010_day)
  } else if (population == "tested") {
    rows <- dplyr::filter(df, test_010_day)
  } else {
    rows <- df
  }

  round(mean(rows[[outcome]]), digits = 3)
}

get_mean_outcomes <- function(population, outcome, x_var, df, ...){
  # Mean of `outcome` within each level of `x_var`, with confidence bounds
  # taken from a no-intercept `felm` fit clustered on `ptid`.
  #
  # Args:
  #   population: "untested" keeps rows where test_010_day is FALSE, "tested"
  #     keeps rows where it is TRUE, anything else keeps every row.
  #   outcome: name of the outcome column in `df`.
  #   x_var: name of the column whose levels define the groups; its values
  #     are expected to be numeric-like, since they are converted with
  #     as.numeric() in the result.
  #   df: data containing `outcome`, `x_var`, `ptid` and `test_010_day`.
  #   ...: ignored.
  # Returns: one row per level of `x_var` with columns x_var, beta, beta_lo,
  #   beta_hi, n (rows in that level) and grouping (= `population`).

  # copy so we can use `setnames` w/o altering input df
  df <- copy(df)
  if (population == "untested") {
    df <- dplyr::filter(df, !test_010_day)
  } else if (population == "tested") {
    df <- dplyr::filter(df, test_010_day)
  }

  table_vars <- c(outcome, x_var)
  print(table_vars)
  assert(
    "Outcome, x_var, ptid in table varnames",
    all(c(table_vars, "ptid") %in% names(df))
  )
  message("N = ", nrow(df))

  df <- setnames(df, table_vars, c("outcome", "x_var")) %>%
    mutate(x_var = factor(x_var)) %>%
    select(ptid, outcome, x_var)

  # Row counts per level. Convert the factor through its labels: as.numeric()
  # applied directly to a factor returns the internal level codes, which only
  # coincide with the labels when the levels are exactly 1..k. The model terms
  # below are keyed by label, so the join key must be the label too.
  n_df <- df %>%
    group_by(x_var) %>%
    summarize(n = n()) %>%
    ungroup %>%
    mutate(x_var = as.numeric(as.character(x_var)))

  # `+ 0` drops the intercept so each coefficient is a level mean
  fit <- felm(outcome ~ x_var + 0| 0 | 0 | ptid, data = df)
  conf <- as.data.frame(confint(fit)) %>%
    rownames_to_column %>%
    setnames(
      c("rowname", "2.5 %", "97.5 %"),
      c("x_var", "beta_lo", "beta_hi")
    )
  tidy_fit <- tidy(fit) %>%
    setnames(
      c("term", "estimate"),
      c("x_var", "beta")
    ) %>%
    select(x_var, beta) %>%
    safe_left_join(conf) %>%
    # terms are named "x_var<level>"; strip the prefix to recover the level
    mutate(x_var = str_remove(x_var, "x_var") %>% as.numeric) %>%
    safe_left_join(n_df) %>%
    mutate(grouping = population)

  return(tidy_fit)
}

get_grouped_mean_outcomes <- function(
  population, outcome, x_var, df, group_var, ...
){
  # Mean of `outcome` within each (x_var level, group_var level) cell, with
  # confidence bounds taken from a no-intercept `felm` fit on the
  # x_var:group_var interaction, clustered on `ptid`.
  #
  # Args:
  #   population: "untested" keeps rows where test_010_day is FALSE, "tested"
  #     keeps rows where it is TRUE, anything else keeps every row.
  #   outcome: name of the outcome column in `df`.
  #   x_var: name of the column whose levels define the x groups; its values
  #     are expected to be numeric-like, since they are converted with
  #     as.numeric() in the result.
  #   df: data containing `outcome`, `x_var`, `group_var`, `ptid` and
  #     `test_010_day`.
  #   group_var: name of the column defining the second grouping.
  #   ...: ignored.
  # Returns: one row per model term with columns x_var, beta, beta_lo,
  #   beta_hi, grouping and n (rows in that cell).
  # NOTE(review): `grouping` is rebuilt with levels(df$group_var), which is
  #   NULL unless `group_var` is a factor - this appears to assume factor
  #   input; confirm against callers.

  # copy so we can use `setnames` w/o altering input df
  df <- copy(df)
  if (population == "untested") {
    df <- dplyr::filter(df, !test_010_day)
  } else if (population == "tested") {
    df <- dplyr::filter(df, test_010_day)
  }

  table_vars <- c(outcome, x_var, group_var)
  assert(
    "Outcome, x_var, group_var, ptid in table varnames",
    all(c(table_vars, "ptid") %in% names(df))
  )

  df <- setnames(df, table_vars, c("outcome", "x_var", "group_var")) %>%
    mutate(x_var = factor(x_var)) %>%
    select(ptid, outcome, x_var, group_var)

  # Row counts per cell. Convert the factor through its labels: as.numeric()
  # applied directly to a factor returns the internal level codes, which only
  # coincide with the labels when the levels are exactly 1..k. The model terms
  # below are keyed by label, so the join key must be the label too.
  n_df <- df %>%
    group_by(x_var, group_var) %>%
    summarize(n = n()) %>%
    ungroup %>%
    mutate(x_var = as.numeric(as.character(x_var))) %>%
    setnames("group_var", "grouping")

  # `+ 0` drops the intercept so each interaction coefficient is a cell mean
  fit <- lfe::felm(outcome ~ x_var:group_var + 0| 0 | 0 | ptid, data = df)
  conf <- as.data.frame(confint(fit)) %>%
    rownames_to_column %>%
    setnames(
      c("rowname", "2.5 %", "97.5 %"),
      c("x_var", "beta_lo", "beta_hi")
    )

  tidy_fit <- tidy(fit) %>%
    setnames(
      c("term", "estimate"),
      c("x_var", "beta")
    ) %>%
    select(x_var, beta) %>%
    safe_left_join(conf) %>%
    # terms are named "x_var<level>:group_var<level>"; split them back apart
    mutate(grouping = str_remove(x_var, ".*group_var")) %>%
    mutate(
      x_var = str_remove(x_var, ":group_var.*") %>%
        str_remove("x_var") %>%
        as.numeric
    ) %>%
    mutate(grouping = factor(grouping, levels = levels(df$group_var))) %>%
    safe_left_join(n_df)

  return(tidy_fit)
}

clustered_se <- function(data, obs_col_name, cluster_by_col_name = 'ptid') {
  # Standard error of the mean of column `obs_col_name`, treating rows that
  # share a value of `cluster_by_col_name` as one cluster. Rows with a missing
  # observation are dropped first.
  #
  # With m clusters, n observations, R the raw mean, and per cluster
  # X = number of observations and Y = sum of observations:
  #   W  = R^2 * sum(X^2) - 2 * R * sum(X * Y) + sum(Y^2)
  #   SE = (m / n) * sqrt(W / ((m - 1) * m))

  # sample-level table: one row per observation, columns `id` and `obs`
  keep_cols <- c(cluster_by_col_name, obs_col_name)
  samples <- data.table(data)[, keep_cols, with = FALSE]
  setnames(samples, keep_cols, c('id', 'obs'))
  samples <- samples[!is.na(obs), ]

  n <- nrow(samples)
  R <- mean(samples[["obs"]])

  # cluster-level table: one row per id, in order of first appearance
  per_id <- samples[, list(X = .N, Y = sum(obs)), by = 'id']
  m <- nrow(per_id)

  sum_Xsq <- sum(per_id[["X"]]^2)
  sum_Ysq <- sum(per_id[["Y"]]^2)
  sum_XY <- sum(per_id[["X"]] * per_id[["Y"]])
  W <- (R * R * sum_Xsq) - (2 * R * sum_XY) + sum_Ysq

  (m / n) * sqrt(W / ((m - 1) * m))
}

str_contains_any <- function(str, substrings){
  # Test whether each element of `str` matches at least one of `substrings`.
  #
  # Args:
  #   str: character vector to search.
  #   substrings: patterns to look for. They are joined with "|" and passed
  #     to grepl(), so they are interpreted as regular expressions.
  # Returns: logical vector the same length as `str`.

  # With no patterns the collapsed pattern would be "", which grepl() matches
  # against every string; nothing to look for means nothing is contained.
  if (length(substrings) == 0) {
    return(rep(FALSE, length(str)))
  }
  contains <- grepl(paste(substrings, collapse = "|"), str)
  return(contains)
}

assign_train_folds <- function(DT, nfolds) {
  # Randomly assign each ptid to one of `nfolds` training folds, cycling the
  # fold numbers within each combination of the three strata flags.
  # Returns one row per ptid with the strata flags, `rand`, `train_fold` and
  # `rand_stratum_rank`.
  strata_cols <- c("test_010_day", "macetrop_pos_or_death_030", "stent_or_cabg_010_day")

  # One row per patient; a flag is TRUE if it is TRUE on any of their rows
  pt_strata <- setDT(
    summarize(
      group_by(DT, ptid),
      test_010_day = any(test_010_day),
      macetrop_pos_or_death_030 = any(macetrop_pos_or_death_030),
      stent_or_cabg_010_day = any(stent_or_cabg_010_day)
    )
  )

  # Shuffle patients by sorting on a uniform random draw
  pt_strata[, rand := runif(.N)]
  setorderv(pt_strata, "rand")

  # Deal fold numbers out in turn within each stratum
  pt_strata[, train_fold := 0L]
  pt_strata[, train_fold := rep_len(1:nfolds, .N), by = strata_cols]

  # Position of each patient within its fold-by-stratum cell, scaled by the
  # cell size, to enable downsampling
  pt_strata[,
    rand_stratum_rank := seq_along(rand) / .N,
    by = c("train_fold", strata_cols)
  ]

  return(pt_strata)
}

linear_f_stat <- function(full, reduced){
  # F statistic comparing a full linear model with a reduced one, computed
  # from their residual sums of squares (deviance) and residual degrees of
  # freedom. The statistic, its null distribution and the upper-tail p-value
  # are reported with message(); only the statistic is returned.
  rss_reduced <- deviance(reduced)
  rss_full <- deviance(full)
  rdf_reduced <- df.residual(reduced)
  rdf_full <- df.residual(full)

  num_df <- rdf_reduced - rdf_full
  denom_df <- rdf_full

  explained_per_df <- (rss_reduced - rss_full) / (rdf_reduced - rdf_full)
  residual_per_df <- rss_full / rdf_full
  f_stat <- explained_per_df / residual_per_df

  p_val <- pf(f_stat, df1 = num_df, df2 = denom_df, lower.tail = FALSE)

  message("F-stat: ", f_stat)
  message("Distribution under the null hypothesis: F(", num_df, ", ", denom_df, ")")
  message("p-value: ", p_val)

  f_stat
}
